{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hackathon #9\n",
    "\n",
    "Written by Eleanor Quint\n",
    "\n",
    "Topics: \n",
    "- Gradient Feature Attribution\n",
    "- Integrated Gradient Feature Attribution\n",
    "\n",
    "This is all setup in a IPython notebook so you can run any code you want to experiment with. Feel free to edit any cell, or add some to run your own code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We'll start with our library imports...\n",
    "from __future__ import print_function\n",
    "\n",
    "import numpy as np                 # to use numpy arrays\n",
    "import tensorflow as tf            # to specify and run computation graphs\n",
    "import tensorflow_datasets as tfds # to load training data\n",
    "import matplotlib.pyplot as plt    # to visualize data and draw plots\n",
    "from tqdm import tqdm              # to track progress of loops\n",
    "import tensorflow_hub as hub       # to load pre-trained models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gradient Feature Attribution\n",
    "\n",
    "We would like to understand what part of an input image a convolutional classifier is \"looking at\" in the process of making a classification. We call this process \"feature attribution\". The goal is to create a heatmap of which pixels in the input image are most important to the model output. The most basic method of doing this is by calculating the gradients of the output with respect to the input image rather than the weights of the model. Where the gradients are largest, the pixels are the most important.\n",
    "\n",
    "We'll load ImageNet-A as the dataset, a set of images labelled with ImageNet labels that were obtained by collecting new data and keeping only those images that ResNet-50 models fail to correctly classify. This should lead to some interesting results in the visualizations. We'll load the Inception model to classify the data."
   ]
  },
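  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of the idea in symbols (the notation here is our own, not taken from the model's documentation): write $F_c(x)$ for the softmax probability the classifier assigns to class $c$ given image $x$. The gradient attribution for pixel $i$ is then\n",
    "\n",
    "$$A_i(x) = \\frac{\\partial F_c(x)}{\\partial x_i},$$\n",
    "\n",
    "which uses the same backpropagation as training, except that the derivative is taken with respect to the input $x$ rather than the weights. A large magnitude of $A_i(x)$ means a small change to pixel $i$ would noticeably change the class-$c$ probability."
   ]
  },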
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "imagenet_ds = tfds.load('imagenet_a')\n",
    "test_ds = imagenet_ds['test'].shuffle(1024)\n",
    "\n",
    "model = tf.keras.Sequential([\n",
    "    hub.KerasLayer(\n",
    "        name='inception_v1',\n",
    "        handle='https://tfhub.dev/google/imagenet/inception_v1/classification/4',\n",
    "        trainable=False),\n",
    "])\n",
    "model.build([None, 224, 224, 3])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To identify what the ImageNet labels are in text we'll load the labels as an array."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_imagenet_labels(file_path):\n",
    "    labels_file = tf.keras.utils.get_file('ImageNetLabels.txt', file_path)\n",
    "    with open(labels_file) as reader:\n",
    "        f = reader.read()\n",
    "        labels = f.splitlines()\n",
    "    return np.array(labels)\n",
    "\n",
    "imagenet_labels = load_imagenet_labels('https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll define a function specifically to compute gradients of one output class in particular with respect to the model input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_gradients(images, target_class_idx):\n",
    "    with tf.GradientTape() as tape:\n",
    "        # We have to use tape.watch to calculate gradients of a non-variable tensor\n",
    "        tape.watch(images)\n",
    "        logits = model(images)\n",
    "        probs = tf.nn.softmax(logits, axis=-1)[:, target_class_idx]\n",
    "    return tape.gradient(probs, images)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we'll get some data, calculate the model prediction, and calculate the gradients with respect to that predicted class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for x in test_ds:\n",
    "    # this casts to float, so we'll divide by 255 right away to get [0,1] data\n",
    "    img = tf.image.resize_with_pad(x['image'], 224, 224) / 255.\n",
    "    model_prediction = tf.argmax(model(tf.expand_dims(img, 0)), axis=1)\n",
    "    grads = compute_gradients(tf.expand_dims(img, 0), model_prediction.numpy()[0])\n",
    "    break\n",
    "\n",
    "# clip to positive value gradients and scale\n",
    "grads = tf.clip_by_value(grads, 0, 1)[0]\n",
    "grads = grads / tf.reduce_max(grads)\n",
    "\n",
    "print(\"Predicted Label:\", imagenet_labels[model_prediction.numpy()[0]])\n",
    "print(\"Correct Label:\", imagenet_labels[x['label'].numpy()])\n",
    "plt.imshow(tf.concat([img, grads], axis=1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Integrated Gradient Feature Attribution\n",
    "\n",
    "Unfortunately, one gradient calculation doesn't tell the full story. It is only a local calculation at the specific value of the pixel. Integrated gradients improves on this by calculating an integral of the pixel gradients by interpolating the pixel from zero to its true value and collecting all the gradients calculated in the in-between steps.\n",
    "\n",
    "We do this by establishing a baseline and interpolating between the baseline and our image in steps."
   ]
  },
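  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In symbols, as a sketch in our own notation: let $F_c$ be the class-$c$ probability, $x'$ the baseline, and $x$ the image. The integrated gradient for pixel $i$ can then be written as\n",
    "\n",
    "$$\\mathrm{IG}_i(x) = (x_i - x'_i) \\int_0^1 \\frac{\\partial F_c\\big(x' + \\alpha (x - x')\\big)}{\\partial x_i} \\, d\\alpha.$$\n",
    "\n",
    "Each value of $\\alpha \\in [0, 1]$ picks one interpolated image on the straight line from the baseline ($\\alpha = 0$) to the input ($\\alpha = 1$)."
   ]
  },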
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "baseline = tf.zeros_like(img)\n",
    "\n",
    "# Generate uniformly spaced coefficients for interpolating between baseline and image\n",
    "interp_steps=50\n",
    "alphas = tf.linspace(start=0.0, stop=1.0, num=interp_steps+1)\n",
    "\n",
    "def interpolate_images(baseline, image, alphas):\n",
    "    alphas_x = alphas[:, tf.newaxis, tf.newaxis, tf.newaxis]\n",
    "    baseline_x = tf.expand_dims(baseline, axis=0)\n",
    "    input_x = tf.expand_dims(image, axis=0)\n",
    "    delta = input_x - baseline_x\n",
    "    images = baseline_x +  alphas_x * delta\n",
    "    return images\n",
    "\n",
    "interpolated_images = interpolate_images(\n",
    "    baseline=baseline,\n",
    "    image=img,\n",
    "    alphas=alphas)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll use quadrature, [a type of numerical integration](https://en.wikipedia.org/wiki/Numerical_integration), to get a good estimate of the integral."
   ]
  },
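  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Concretely, if $g(\\alpha)$ denotes the gradient at interpolation coefficient $\\alpha$ and we use $m$ equal steps with $\\alpha_k = k/m$, the trapezoidal rule (one common quadrature scheme, and the one assumed in this sketch) gives\n",
    "\n",
    "$$\\int_0^1 g(\\alpha) \\, d\\alpha \\approx \\frac{1}{m} \\sum_{k=1}^{m} \\frac{g(\\alpha_{k-1}) + g(\\alpha_k)}{2}.$$\n",
    "\n",
    "This needs $m + 1$ gradient evaluations, and the right-hand side is simply the mean of the averages of neighboring gradients."
   ]
  },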
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def integral_approximation(gradients):\n",
    "    # riemann_trapezoidal\n",
    "    grads = (gradients[:-1] + gradients[1:]) / tf.constant(2.0)\n",
    "    integrated_gradients = tf.math.reduce_mean(grads, axis=0)\n",
    "    return integrated_gradients"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the interpolated images and the ability to calculate the integral, we can calculate the integrated gradient."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def integrated_gradients(image,\n",
    "                         target_class_idx,\n",
    "                         m_steps=50,\n",
    "                         batch_size=32):\n",
    "    # 1. Generate alphas.\n",
    "    baseline = tf.zeros_like(image)\n",
    "    alphas = tf.linspace(start=0.0, stop=1.0, num=m_steps+1)\n",
    "\n",
    "    # Initialize TensorArray outside loop to collect gradients.    \n",
    "    gradient_batches = tf.TensorArray(tf.float32, size=m_steps+1)\n",
    "\n",
    "    # Iterate alphas range and batch computation for speed, memory efficiency, and scaling to larger m_steps.\n",
    "    for alpha in tqdm(tf.range(0, len(alphas), batch_size)):\n",
    "        from_ = alpha\n",
    "        to = tf.minimum(from_ + batch_size, len(alphas))\n",
    "        alpha_batch = alphas[from_:to]\n",
    "\n",
    "        # 2. Generate interpolated inputs between baseline and input.\n",
    "        interpolated_path_input_batch = interpolate_images(baseline=baseline,\n",
    "                                                           image=image,\n",
    "                                                           alphas=alpha_batch)\n",
    "\n",
    "        # 3. Compute gradients between model outputs and interpolated inputs.\n",
    "        gradient_batch = compute_gradients(images=interpolated_path_input_batch,\n",
    "                                           target_class_idx=target_class_idx)\n",
    "\n",
    "        # Write batch indices and gradients to extend TensorArray.\n",
    "        gradient_batches = gradient_batches.scatter(tf.range(from_, to), gradient_batch)    \n",
    "\n",
    "    # Stack path gradients together row-wise into single tensor.\n",
    "    total_gradients = gradient_batches.stack()\n",
    "\n",
    "    # 4. Integral approximation through averaging gradients.\n",
    "    avg_gradients = integral_approximation(gradients=total_gradients)\n",
    "    \n",
    "    # 5. Scale integrated gradients with respect to input.\n",
    "    integrated_gradients = (image - baseline) * avg_gradients\n",
    "\n",
    "    return integrated_gradients"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And finally, plot them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_img_attributions(image,\n",
    "                          target_class_idx,\n",
    "                          m_steps=50,\n",
    "                          cmap=None,\n",
    "                          overlay_alpha=0.4):\n",
    "\n",
    "    attributions = integrated_gradients(image=image,\n",
    "                                        target_class_idx=target_class_idx,\n",
    "                                        m_steps=m_steps)\n",
    "\n",
    "    # Sum of the attributions across color channels for visualization.\n",
    "    # The attribution mask shape is a grayscale image with height and width\n",
    "    # equal to the original image.\n",
    "    attribution_mask = tf.reduce_sum(tf.math.abs(attributions), axis=-1)\n",
    "\n",
    "    fig, axs = plt.subplots(nrows=2, ncols=2, squeeze=False, figsize=(8, 8))\n",
    "\n",
    "    # leave this in as long as the keras VGG-19 preprocessing is being used\n",
    "    image = image[:,:,::-1]\n",
    "    \n",
    "    axs[0, 0].set_title('Baseline image')\n",
    "    axs[0, 0].imshow(tf.zeros_like(image))\n",
    "    axs[0, 0].axis('off')\n",
    "\n",
    "    axs[0, 1].set_title('Original image')\n",
    "    axs[0, 1].imshow(image)\n",
    "    axs[0, 1].axis('off')\n",
    "\n",
    "    axs[1, 0].set_title('Attribution mask')\n",
    "    axs[1, 0].imshow(attribution_mask, cmap=cmap)\n",
    "    axs[1, 0].axis('off')\n",
    "\n",
    "    axs[1, 1].set_title('Overlay')\n",
    "    axs[1, 1].imshow(attribution_mask, cmap=cmap)\n",
    "    axs[1, 1].imshow(image, alpha=overlay_alpha)\n",
    "    axs[1, 1].axis('off')\n",
    "\n",
    "    plt.tight_layout()\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The result may or may not make sense to a viewer (and may or may not mean much depending on the problem), but it's at least one mathematical description of how important each pixel is to the classifier."
   ]
  },
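  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One property can serve as a numerical sanity check, stated here as a sketch in our own notation ($F_c$ for the class-$c$ probability, $x'$ for the baseline). The exact integrated gradients satisfy a completeness property: the attributions sum to the change in model output between the baseline and the image,\n",
    "\n",
    "$$\\sum_i \\mathrm{IG}_i(x) = F_c(x) - F_c(x').$$\n",
    "\n",
    "With a finite number of steps the two sides only agree approximately, and a large gap suggests that `m_steps` should be increased."
   ]
  },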
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_ = plot_img_attributions(image=img,\n",
    "                          target_class_idx=x['label'],\n",
    "                          m_steps=128,\n",
    "                          cmap=plt.cm.inferno,\n",
    "                          overlay_alpha=0.4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Thanks to [TensorFlow's documentation](https://www.tensorflow.org/tutorials/interpretability/integrated_gradients) for much of the code used in the Integrated Gradient example.\n",
    "\n",
    "### Homework\n",
    "\n",
    "None! This is the last hackathon. I hope your projects go well and I look forward to the final presentations. Have a great rest of your semester."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dl_notebooks",
   "language": "python",
   "name": "dl_notebooks"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
